import numpy as np
from scipy.stats import norm
import scipy
import torch
from diffusers.utils.torch_utils import randn_tensor


class GTWatermarkCachedLearnedWM:
    """Fourier-space watermark with a shaped mask and a caller-supplied key.

    The watermark lives in the centred (``fftshift``-ed) 2-D FFT of the
    latents, on a single channel (``w_channel``) and inside a region whose
    outline is selected by ``shape``. ``inject_watermark`` receives the key
    pattern as an argument. ``self.gt_patch``, the reference pattern read by
    ``tree_ring_p_value`` and ``l1_value``, is only generated in ``__init__``
    when ``w_radius != 10`` and is ``None`` otherwise.

    Every mask is built as ``size x size`` with ``size = tensor.shape[-1]``,
    so the last two dimensions of the tensors involved must be equal.
    """

    # Values of ``shape`` understood by the mask/pattern builders below.
    _SUPPORTED_SHAPES = ("circle", "square", "triangle", "oval", "rectangle")

    def __init__(self, device, watermarking_mask, w_channel=3, w_radius=10, wm_pattern='rings', shape='circle', generator=None):
        """Build the watermark mask and, when ``w_radius != 10``, a key pattern.

        Args:
            device: torch device the mask/pattern tensors are moved to
                (``None`` leaves them where they were created).
            watermarking_mask: template tensor; only its ``.shape`` is used.
                Presumably ``(batch, channels, H, W)``; confirm with callers.
            w_channel: channel index that carries the watermark.
            w_radius: number of pattern rings (outer extent of the pattern),
                and the radius (circle) / side length (triangle) of the mask.
            wm_pattern: stored as ``self.wm_pattern``; not read in this class.
            shape: one of ``_SUPPORTED_SHAPES``.
            generator: optional torch generator forwarded to ``randn_tensor``.

        Raises:
            ValueError: if ``shape`` is not supported.
        """
        self.w_channel = w_channel
        self.w_radius = w_radius
        # Fixed second extent of the oval (r_y) and rectangle (width) pattern rings.
        self.width = 10
        self.wm_pattern = wm_pattern
        self.device = device
        self.shape = shape
        # Fail fast: an unknown shape would otherwise produce an empty mask.
        if self.shape not in self._SUPPORTED_SHAPES:
            raise self._unsupported_shape()

        # NOTE(review): with the default radius no key pattern is generated, so
        # scoring needs ``self.gt_patch`` to be assigned externally; confirm
        # that gating this on ``w_radius != 10`` is intended.
        if w_radius != 10:
            gt_init = randn_tensor(watermarking_mask.shape, generator=generator, device=self.device, dtype=torch.float32)
            self.gt_patch = self._get_watermarking_pattern(gt_init).to(device)
        else:
            self.gt_patch = None

        self.watermarking_mask = self._get_watermarking_mask(watermarking_mask)

    def inject_watermark(self, gt_patch, latents):
        """Replace the masked Fourier coefficients of ``latents`` with ``gt_patch``.

        Args:
            gt_patch: complex key in the centred frequency domain; it must
                broadcast against ``latents`` and ``self.watermarking_mask``.
            latents: real tensor; the FFT is taken over its last two dims.

        Returns:
            The real part of the inverse FFT after the substitution. The
            imaginary part is discarded, so the key survives exactly only if
            the modified spectrum is Hermitian-symmetric.
        """
        latents_fft = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
        # Out-of-place masked overwrite: keep coefficients outside the mask,
        # take ``gt_patch`` inside it (lets ``gt_patch`` broadcast).
        latents_fft = latents_fft * ~(self.watermarking_mask) + gt_patch * self.watermarking_mask
        latents_w = torch.fft.ifft2(torch.fft.ifftshift(latents_fft, dim=(-1, -2))).real
        return latents_w

    # NOTE: placeholder kept so that code referencing this method still
    # resolves it (intended: the probability of being watermarked).
    def one_minus_p_value(self, latents):
        """Not implemented for this watermark variant.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def _circle_mask(self, size=64, r=10, x_offset=0, y_offset=0):
        """Return a ``size x size`` bool array that is True within radius ``r``.

        The centre is ``(size // 2 + x_offset, size // 2 + y_offset)`` in
        ``(x, y)`` coordinates, where ``y`` is flipped (row 0 has
        ``y == size - 1``), so a positive ``y_offset`` moves toward row 0.
        Adapted from
        https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3

        NOTE(review): for even ``size`` the flip puts the centre on row
        ``size // 2 - 1 - y_offset``, while ``fftshift`` places the DC bin on
        row ``size // 2``; confirm that this one-row offset is acceptable.
        """
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        y = y[::-1]
        return ((x - x0)**2 + (y-y0)**2)<= r**2

    def _square_mask(self, size=64, side_length=12, x_offset=0, y_offset=0):
        """Return a ``size x size`` bool array that is True inside a centred square.

        Uses the same centre/flipped-y convention as ``_circle_mask``. The
        bounds are strict (``< side_length // 2``), so the square spans
        ``2 * (side_length // 2) - 1`` pixels per side (empty for side < 2).
        """
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        half_side = side_length // 2
        y, x = np.ogrid[:size, :size]
        y = y[::-1]

        return (np.abs(x - x0) < half_side) & (np.abs(y - y0) < half_side)

    def _triangle(self, size=19, side_length=7, x_offset=0, y_offset=0):
        """Return a ``size x size`` bool array that is True inside an equilateral triangle.

        The triangle has its apex up (largest ``y``), is centred on
        ``(size // 2 + x_offset, size // 2 + y_offset)``, and uses the same
        flipped-y convention as ``_circle_mask``. Boundary pixels are included.
        """
        x0 = size // 2 + x_offset
        y0 = size // 2 + y_offset
        height = np.sqrt(3) / 2 * side_length  # Height of an equilateral triangle

        y, x = np.ogrid[:size, :size]
        y = y[::-1]  # Flip the y-axis to match coordinate orientation

        # Vertices in counter-clockwise order: apex, bottom-left, bottom-right.
        vertices = np.array([
            [x0, y0 + height / 2],
            [x0 - side_length / 2, y0 - height / 2],
            [x0 + side_length / 2, y0 - height / 2]
        ])

        def edge_eq(p1, p2):
            # Negated 2-D cross product (p2 - p1) x (p - p1): for a
            # counter-clockwise edge it is <= 0 on (or inside) the edge.
            return (x - p1[0]) * (p2[1] - p1[1]) - (y - p1[1]) * (p2[0] - p1[0])

        # A point is inside when it is on the inner side of all three edges.
        # (The three ">= 0" outer half-planes never intersect, which would give
        # an empty mask.)
        mask = (
            (edge_eq(vertices[0], vertices[1]) <= 0) &
            (edge_eq(vertices[1], vertices[2]) <= 0) &
            (edge_eq(vertices[2], vertices[0]) <= 0)
        )

        return mask

    def _oval(self, size=19, r_x=7, r_y=4, x_offset=0, y_offset=0):
        """Return a ``size x size`` bool array that is True inside an axis-aligned ellipse.

        ``r_x`` and ``r_y`` are the semi-axes along x and y; the boundary is
        included. Same centre/flipped-y convention as ``_circle_mask``.
        """
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        y = y[::-1]

        return ((x - x0)**2 / r_x**2 + (y - y0)**2 / r_y**2) <= 1

    def _rectangle(self, size=64, width=7, height=4, x_offset=0, y_offset=0):
        """Return a ``size x size`` bool array that is True inside a centred rectangle.

        The bounds are inclusive (``<= width // 2`` / ``<= height // 2``), so
        the rectangle spans ``2 * (width // 2) + 1`` columns and
        ``2 * (height // 2) + 1`` rows. Same centre/flipped-y convention as
        ``_circle_mask``.
        """
        x0 = size // 2 + x_offset
        y0 = size // 2 + y_offset
        half_width = width // 2
        half_height = height // 2
        y, x = np.ogrid[:size, :size]
        y = y[::-1]

        return (np.abs(x - x0) <= half_width) & (np.abs(y - y0) <= half_height)

    def _unsupported_shape(self):
        """Build the ValueError used when ``self.shape`` is not recognised."""
        return ValueError(
            f"unsupported shape {self.shape!r}; expected one of {self._SUPPORTED_SHAPES}"
        )

    def _pattern_ring_mask(self, size, i):
        """Return the numpy bool mask of pattern ring ``i`` for ``self.shape``.

        Raises:
            ValueError: if ``self.shape`` is not supported.
        """
        if self.shape == "circle":
            return self._circle_mask(size, r=i)
        if self.shape == "square":
            return self._square_mask(size, side_length=i)
        if self.shape == "triangle":
            return self._triangle(size, side_length=i)
        if self.shape == "oval":
            return self._oval(size, r_x=i, r_y=self.width)
        if self.shape == "rectangle":
            return self._rectangle(size, height=i, width=self.width)
        raise self._unsupported_shape()

    def _get_watermarking_pattern(self, gt_init):
        """Turn ``gt_init`` into a ring-structured key pattern (in FFT space).

        Starting from the centred FFT of ``gt_init``, nested regions of
        ``self.shape`` with extent ``w_radius, w_radius - 1, ..., 1`` are
        painted from the outermost inward. Each region (channel ``w_channel``,
        every batch entry) is filled with the single coefficient at
        ``[0, w_channel, 0, i]`` of the spectrum.

        Args:
            gt_init: real tensor whose last two dims are equal.

        Returns:
            The complex, centred spectrum with the rings painted in.

        Raises:
            ValueError: if ``self.shape`` is not supported (and ``w_radius > 0``).
        """
        gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
        size = gt_init.shape[-1]
        for i in range(self.w_radius, 0, -1):  # from outer ring to inner ring
            ring = torch.tensor(self._pattern_ring_mask(size, i)).to(self.device)
            gt_patch[:, self.w_channel, ring] = gt_patch[0, self.w_channel, 0, i].item()
        return gt_patch

    def _region_mask(self, size):
        """Return the numpy bool mask of the full watermark region for ``self.shape``.

        Also prints the shape name.

        Raises:
            ValueError: if ``self.shape`` is not supported.
        """
        # NOTE(review): only the circle and triangle follow ``w_radius``; the
        # square, oval and rectangle use fixed sizes, which also differ from
        # the extents of the pattern rings. Confirm these values are intended.
        if self.shape == "circle":
            print("circle")
            return self._circle_mask(size, r=self.w_radius)
        if self.shape == "square":
            print("square")
            return self._square_mask(size, side_length=10)
        if self.shape == "triangle":
            print("triangle")
            return self._triangle(size, side_length=self.w_radius)
        if self.shape == "oval":
            print("oval")
            return self._oval(size, r_x=21, r_y=5)
        if self.shape == "rectangle":
            print("rectangle")
            return self._rectangle(size, height=29, width=10)
        raise self._unsupported_shape()

    def _get_watermarking_mask(self, gt_patch):
        """Build the bool mask that marks the watermarked Fourier coefficients.

        Args:
            gt_patch: template tensor; only its ``.shape`` is used.

        Returns:
            Bool tensor of ``gt_patch.shape`` on ``self.device``. It is True
            only on channel ``w_channel`` (every batch entry), inside the
            region for ``self.shape``.

        Raises:
            ValueError: if ``self.shape`` is not supported.
        """
        watermarking_mask = torch.zeros(gt_patch.shape, dtype=torch.bool).to(self.device)
        region = self._region_mask(gt_patch.shape[-1])
        watermarking_mask[:, self.w_channel] = torch.tensor(region).to(self.device)
        return watermarking_mask

    def _masked_coefficients(self, latents):
        """Return ``(target, observed)`` real vectors of the masked coefficients.

        Each vector is the masked complex coefficients with their real parts
        followed by their imaginary parts. ``target`` comes from
        ``self.gt_patch``; ``observed`` comes from the centred FFT of ``latents``.

        Raises:
            RuntimeError: if ``self.gt_patch`` has not been set.
        """
        if self.gt_patch is None:
            raise RuntimeError(
                "gt_patch is None (it is only generated when w_radius != 10); "
                "assign self.gt_patch before scoring"
            )
        target = self.gt_patch[self.watermarking_mask].flatten()
        observed = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))[self.watermarking_mask].flatten()
        return (
            torch.cat([target.real, target.imag]),
            torch.cat([observed.real, observed.imag]),
        )

    def tree_ring_p_value(self, latents):
        """Score ``latents`` against ``self.gt_patch`` with a non-central chi-squared CDF.

        The masked coefficients are normalised by the standard deviation
        ``sigma`` of the observed coefficients. The returned value is
        ``ncx2.cdf(sum(((obs - target) / sigma)**2), df=len(target),
        nc=sum((target / sigma)**2))``. Smaller values mean the observed
        coefficients are closer to the key.

        Raises:
            RuntimeError: if ``self.gt_patch`` has not been set.
        """
        target_patch, reversed_latents_w_fft = self._masked_coefficients(latents)

        sigma_w = reversed_latents_w_fft.std()
        lambda_w = (target_patch ** 2 / sigma_w ** 2).sum().item()
        x_w = (((reversed_latents_w_fft - target_patch) / sigma_w) ** 2).sum().item()
        p_w = scipy.stats.ncx2.cdf(x=x_w, df=len(target_patch), nc=lambda_w)
        return p_w

    def l1_value(self, latents):
        """Return the L1 distance (0-d tensor) between key and observed masked coefficients.

        Raises:
            RuntimeError: if ``self.gt_patch`` has not been set.
        """
        target_patch, reversed_latents_w_fft = self._masked_coefficients(latents)
        dif = torch.abs(target_patch - reversed_latents_w_fft).sum()
        return dif


# Quick manual sanity check: print how many pixels each mask helper selects.
# Guarded so that importing this module neither builds an object nor prints.
if __name__ == "__main__":
    tester = GTWatermarkCachedLearnedWM(None, torch.zeros((1, 4, 64, 64)), shape='square')
    print(tester._circle_mask(size=64, r=10).sum())
    print(tester._square_mask(size=64, side_length=18).sum())
    print(tester._oval(64, r_x=21, r_y=5).sum())
    print(tester._rectangle(64, width=29, height=10).sum())